import torch
from einops import repeat

# repeat: 张量扩张，用于对张量的某一个维度进行复制，可用于替换pytorch中的repeat

# 维度复制(新增维度)
x = torch.randn(3, 4)
y = repeat(x, "i j -> i j k", k=5)  # torch.Size([3, 4, 5])

# 维度复制(沿某维度扩展)
x = torch.randn(3, 4)
y = repeat(x, "i j ->i (j k)", k=2)  # torch.Size([3, 8])

pass
